{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pairplot with GMM: Visualizing High Dimensional Data and Clustering\n",
    "\n",
    "This example shows how to visualize high-dimensional data with the pairplot_with_gmm function. The function is useful for clustering or for learning about the geometry of high-dimensional data: it plots pairwise comparisons across every dimension of the data and overlays GMM clustering, with ellipses drawn from the GMM covariances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import graspologic\n",
    "\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulate a binary graph using stochastic block model\n",
    "The 3-block model is defined in the next cell: n_communities gives the number of vertices in each block, and p gives the edge probability within and between blocks.\n",
    "\n",
    "The first 50 vertices belong to block 1, the next 50 vertices belong to block 2, and the last 50 vertices belong to block 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.simulations import sbm\n",
    "\n",
    "n_communities = [50, 50, 50]\n",
    "p = [[0.5, 0.1, 0.05], \n",
    "     [0.1, 0.4, 0.15], \n",
    "     [0.05, 0.15, 0.3],]\n",
    "\n",
    "np.random.seed(2)\n",
    "A = sbm(n_communities, p)"
   ]
  },
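  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the empirical edge density of each block pair can be compared with the matrix p. The cell below is a minimal sketch: block_densities is a hypothetical helper written for this example (not part of graspologic), and it assumes the simulated graph has no self-loops, so the diagonal is excluded from the within-block counts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def block_densities(adj, sizes):\n",
    "    # Hypothetical helper: fraction of possible edges present in each block pair\n",
    "    bounds = np.concatenate([[0], np.cumsum(sizes)])\n",
    "    k = len(sizes)\n",
    "    dens = np.zeros((k, k))\n",
    "    for i in range(k):\n",
    "        for j in range(k):\n",
    "            block = adj[bounds[i]:bounds[i + 1], bounds[j]:bounds[j + 1]]\n",
    "            if i == j:\n",
    "                # Assumes no self-loops: only n * (n - 1) ordered pairs are possible\n",
    "                n = sizes[i]\n",
    "                dens[i, j] = block.sum() / (n * (n - 1))\n",
    "            else:\n",
    "                dens[i, j] = block.mean()\n",
    "    return dens\n",
    "\n",
    "# Each entry should be close to the corresponding entry of p\n",
    "print(np.round(block_densities(A, n_communities), 2))"
   ]
  },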
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Embed using adjacency spectral embedding to obtain lower dimensional representation of the graph\n",
    "\n",
    "The embedding dimension is automatically chosen. It should embed to 3 dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from graspologic.embed import AdjacencySpectralEmbed\n",
    "\n",
    "ase = AdjacencySpectralEmbed()\n",
    "X = ase.fit_transform(A)\n",
    "\n",
    "print(X.shape)"
   ]
  },
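  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give some intuition for what the embedding computes, the cell below sketches the core idea in plain NumPy: take the leading singular vectors of the adjacency matrix and scale each by the square root of its singular value. This is a simplified illustration, not graspologic's implementation; ase_sketch is a hypothetical helper, the number of dimensions is taken from X rather than chosen automatically, and the columns may differ in sign from those of X."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def ase_sketch(adj, n_components):\n",
    "    # Full SVD of the adjacency matrix; singular values are returned in descending order\n",
    "    U, S, Vt = np.linalg.svd(adj)\n",
    "    # Keep the leading components and scale each column by sqrt(singular value)\n",
    "    return U[:, :n_components] * np.sqrt(S[:n_components])\n",
    "\n",
    "# Reuse the dimension selected by AdjacencySpectralEmbed above\n",
    "X_sketch = ase_sketch(A, X.shape[1])\n",
    "print(X_sketch.shape)"
   ]
  },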
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use pairplot_with_gmm to plot the embedded data with ellipses\n",
    "\n",
    "First we fit a three-component GMM to the embedded data. We then pass the fitted GMM along with the data to create a pairplot with the ellipses estimated by the GMM over each component of the data. No block labels are passed in this first call, so labels is left as None."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "from graspologic.plot import pairplot_with_gmm\n",
    "from sklearn.mixture import GaussianMixture\n",
    "gmm = GaussianMixture(n_components=3, covariance_type='full').fit(X)\n",
    "graph = pairplot_with_gmm(X, gmm, labels=None, cluster_palette=None, label_palette=None)"
   ]
  },
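  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ellipses summarize the fitted covariances. As a rough sketch of the idea, the cell below takes the covariance of the first GMM component, restricts it to two embedding dimensions, and derives an ellipse's width, height, and rotation from its eigendecomposition. This is an illustration under stated assumptions rather than the plotting code graspologic uses: ellipse_params and the choice of two standard deviations (n_std) are hypothetical, and the embedding is assumed to have at least two dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def ellipse_params(cov2d, n_std=2.0):\n",
    "    # Eigendecomposition of a symmetric 2x2 covariance (eigenvalues in ascending order)\n",
    "    vals, vecs = np.linalg.eigh(cov2d)\n",
    "    # The major axis points along the eigenvector with the largest eigenvalue\n",
    "    major = vecs[:, 1]\n",
    "    angle = np.degrees(np.arctan2(major[1], major[0]))\n",
    "    # Full width and height span n_std standard deviations on each side of the mean\n",
    "    width, height = 2 * n_std * np.sqrt(vals[::-1])\n",
    "    return width, height, angle\n",
    "\n",
    "# With covariance_type='full', covariances_ has shape (n_components, n_features, n_features)\n",
    "dims = [0, 1]\n",
    "cov2d = gmm.covariances_[0][np.ix_(dims, dims)]\n",
    "print(ellipse_params(cov2d))"
   ]
  },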
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want to give your ellipses specific colors, you can do the following. Please note that because the cluster palette is keyed by cluster assignment number (in this case 0, 1, 2), the cluster colors in the following plot will not necessarily match those of the underlying scatter points. GMM does not number its clusters in any particular order, so the colors may be permuted relative to the blocks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "from sklearn.mixture import GaussianMixture\n",
    "gmm = GaussianMixture(n_components=3, covariance_type='full').fit(X)\n",
    "labels = ['Block 1'] * 50 + ['Block 2'] * 50 + ['Block 3'] * 50\n",
    "# We need 2 * n_components colors so that labels and clusters get different colors from the same palette\n",
    "color = sns.color_palette(\"Set1\", 2 * 3)\n",
    "cluster_palette = {0: color[0], 1: color[1], 2: color[2]}\n",
    "label_palette = {\"Block 1\": color[3], \"Block 2\": color[4], \"Block 3\": color[5]}\n",
    "graph = pairplot_with_gmm(X, gmm, labels=labels, cluster_palette=cluster_palette, label_palette=label_palette)"
   ]
  }
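  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If matching colors are wanted, one possible approach is to align the GMM component numbers with the blocks before building the cluster palette. The cell below is a sketch of that idea, not something pairplot_with_gmm does for you: it counts how many vertices of each block fall in each component and uses scipy.optimize.linear_sum_assignment to pick the matching with the most agreement. The names true_blocks, component_to_block, and aligned_palette are introduced here for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "\n",
    "# Integer block index for each vertex, matching the 50/50/50 layout above\n",
    "true_blocks = np.repeat([0, 1, 2], 50)\n",
    "pred = gmm.predict(X)\n",
    "\n",
    "# counts[i, j] = number of vertices in block i assigned to GMM component j\n",
    "counts = np.zeros((3, 3), dtype=int)\n",
    "np.add.at(counts, (true_blocks, pred), 1)\n",
    "\n",
    "# Choose the one-to-one block/component matching that maximizes agreement\n",
    "block_idx, comp_idx = linear_sum_assignment(counts, maximize=True)\n",
    "component_to_block = {int(c): f'Block {b + 1}' for b, c in zip(block_idx, comp_idx)}\n",
    "print(component_to_block)\n",
    "\n",
    "# Reuse each block's label color for its matched component\n",
    "aligned_palette = {c: label_palette[name] for c, name in component_to_block.items()}\n",
    "graph = pairplot_with_gmm(X, gmm, labels=labels, cluster_palette=aligned_palette, label_palette=label_palette)"
   ]
  }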
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
